{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "97qSvXYr2BH6"
      },
      "source": [
        "This Colab is a tool for visualizing the 3D NeRF and semantic scene representations produced as part of NeSF: Neural Semantic Fields.\n",
        "\n",
        "\n",
        "The project website for NeSF can be found here: https://nesf3d.github.io/\n",
        "\n",
        "\n",
        "Accompanying code can be found on GitHub at: https://github.com/google-research/jax3d/tree/main/jax3d/projects/nesf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tzSzNG0M8V4d"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NrfRlUAR3B5R"
      },
      "outputs": [],
      "source": [
        "# @title Set up environment\n",
        "# Clones jax3d into /content/jax3d and installs it, together with the NeSF\n",
        "# extras, into the Colab runtime.\n",
        "import sys\n",
        "!git clone https://github.com/google-research/jax3d.git\n",
        "%cd /content/jax3d\n",
        "!python -m pip install --upgrade pip\n",
        "!pip install .\n",
        "!pip install --upgrade \"jax3d[nesf]\"\n",
        "# CPU build of JAX.\n",
        "!pip install --upgrade \"jax[cpu]\"\n",
        "# NOTE(review): flax is pinned; presumably because a later cell uses the\n",
        "# `flax.optim` API -- confirm before changing this version.\n",
        "!pip install flax==0.5.3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xRkf_wpM3ReK"
      },
      "outputs": [],
      "source": [
        "# @title Configure datasets and checkpoints\n",
        "# Every archive is unpacked relative to the repository root so that, once the\n",
        "# \"Reorganize code\" cell moves /content/jax3d/* up to /content, the data sits at\n",
        "# the paths referenced by GIN_CONFIG:\n",
        "#   /content/klevr, /content/klevr_checkpoints, /content/klevr_semantic_checkpoints\n",
        "%cd /content/jax3d\n",
        "\n",
        "# KLEVR dataset.\n",
        "!wget https://storage.googleapis.com/kubric-public/data/NeSFDatasets/NeSF%20datasets/klevr.tar.gz\n",
        "!tar -xvf klevr.tar.gz\n",
        "!rm klevr.tar.gz\n",
        "\n",
        "# Pretrained per-scene NeRF checkpoints.\n",
        "!wget https://storage.googleapis.com/kubric-public/data/NeSFDatasets/NeRF%20checkpoints/klevr.tar.gz\n",
        "!mkdir klevr_checkpoints\n",
        "!mv klevr.tar.gz klevr_checkpoints\n",
        "%cd klevr_checkpoints/\n",
        "!tar -xvf klevr.tar.gz\n",
        "\n",
        "# Pretrained NeSF (semantic) checkpoints. Return to the repository root first;\n",
        "# otherwise klevr_semantic_checkpoints/ is created inside klevr_checkpoints/ and\n",
        "# the TrainParams.train_dir path in GIN_CONFIG does not exist.\n",
        "%cd /content/jax3d\n",
        "!wget https://storage.googleapis.com/kubric-public/data/NeSFDatasets/NeSFCheckpoints/klevr.tar.gz\n",
        "!mkdir klevr_semantic_checkpoints\n",
        "!mv klevr.tar.gz klevr_semantic_checkpoints\n",
        "%cd klevr_semantic_checkpoints/\n",
        "!tar -xvf klevr.tar.gz"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iAln2tMq3l-K"
      },
      "outputs": [],
      "source": [
        "# @title Reorganize code\n",
        "# Flatten the clone: move everything inside /content/jax3d up into /content so\n",
        "# the downloaded data and checkpoints sit directly under /content.\n",
        "%cd /content\n",
        "# Rename the clone first; presumably its own top-level `jax3d/` package\n",
        "# directory would otherwise collide with /content/jax3d when moved up.\n",
        "!mv jax3d jax3d_old\n",
        "!mv /content/jax3d_old/* /content\n",
        "# Remove what is left of the renamed clone directory.\n",
        "!rm -R /content/jax3d_old\n",
        "%cd /content"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nAilt0wF7hZb"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WFPqg2yW_Bci"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "import sys\n",
        "import chex\n",
        "import flax\n",
        "import imageio\n",
        "import plotly.graph_objects as go\n",
        "import numpy as np\n",
        "import scipy\n",
        "import sklearn\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy\n",
        "import jax\n",
        "import pandas as pd\n",
        "import plotly.express as px\n",
        "from jax import numpy as jnp\n",
        "from flax import linen as nn\n",
        "import tensorflow as tf\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "llw3azyR_GC4"
      },
      "outputs": [],
      "source": [
        "import gin\n",
        "gin.enter_interactive_mode()  # Avoid error when reloading modules\n",
        "\n",
        "import jax3d.projects.nesf as j3d\n",
        "from jax3d.projects.nesf import nerfstatic as nf\n",
        "\n",
        "from jax3d.projects.nesf.nerfstatic.utils import train_utils\n",
        "from jax3d.projects.nesf.nerfstatic.utils import eval_utils\n",
        "from jax3d.projects.nesf.nerfstatic import datasets\n",
        "from jax3d.projects.nesf.nerfstatic.datasets import klevr\n",
        "from jax3d.projects.nesf.nerfstatic.models import models\n",
        "from jax3d.projects.nesf.nerfstatic.models import model_utils\n",
        "from jax3d.projects.nesf.nerfstatic.nerf import utils\n",
        "from jax3d.projects.nesf.nerfstatic.utils import config as nerf_config\n",
        "from jax3d.projects.nesf.nerfstatic.utils import types\n",
        "from jax3d.projects.nesf.nerfstatic.utils import semantic_utils\n",
        "\n",
        "import importlib\n",
        "importlib.reload(j3d)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v9S8jmTN7m4K"
      },
      "source": [
        "# Load Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ORBjnzj8EBUO"
      },
      "outputs": [],
      "source": [
        "# Gin bindings for this visualization run; parsed by the \"Load experiment HParams\"\n",
        "# cell. The data and checkpoint paths assume the /content layout produced by the\n",
        "# setup cells.\n",
        "# NOTE(review): `EvalParams.chunk = 32788` looks like a typo for 32768 (2**15) -- confirm.\n",
        "GIN_CONFIG = \"\"\"\n",
        "ConfigParams.models = \"NerfParams\"\n",
        "\n",
        "DatasetParams.batch_size = 16384\n",
        "DatasetParams.data_dir = '/content/klevr'\n",
        "TrainParams.nerf_model_ckpt = '/content/klevr_checkpoints'\n",
        "DatasetParams.eval_scenes = '0:1'\n",
        "DatasetParams.novel_scenes = '80:81'\n",
        "DatasetParams.train_scenes = '0:1'\n",
        "\n",
        "DatasetParams.dataset = 'klevr'\n",
        "DatasetParams.factor = 0\n",
        "DatasetParams.num_scenes_per_batch = 1\n",
        "DatasetParams.max_num_train_images_per_scene = 9\n",
        "DatasetParams.max_num_test_images_per_scene = 4\n",
        "\n",
        "ModelParams.num_semantic_classes = 6  # Needed to make semantic predictions.\n",
        "\n",
        "TrainParams.mode = \"SEMANTIC\"\n",
        "TrainParams.print_every = 100\n",
        "TrainParams.train_dir = \"/content/klevr_semantic_checkpoints/klevr/\"  # Will be overriden by XManager.\n",
        "TrainParams.train_steps = 25000\n",
        "TrainParams.save_every = 500\n",
        "TrainParams.semantic_smoothness_regularization_num_points_per_device = 8192\n",
        "TrainParams.semantic_smoothness_regularization_weight = 0.01\n",
        "TrainParams.semantic_smoothness_regularization_stddev = 0.05\n",
        "TrainParams.nerf_model_recompute_sigma_grid = True\n",
        "TrainParams.nerf_model_recompute_sigma_grid_shape = (64, 64, 64)\n",
        "TrainParams.nerf_model_recompute_sigma_grid_convert_sigma_to_density = True\n",
        "\n",
        "EvalParams.chunk = 32788\n",
        "EvalParams.eval_num_log_images = 8\n",
        "EvalParams.eval_once = False\n",
        "\n",
        "ModelParams.unet_depth = 3\n",
        "ModelParams.unet_feature_size = (32, 64, 128, 256)\n",
        "ModelParams.num_fine_samples = 192\n",
        "ModelParams.apply_random_scene_rotations = True\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IgcuQ3AaOMeg"
      },
      "outputs": [],
      "source": [
        "from absl import app\n",
        "\n",
        "# Addresses `UnrecognizedFlagError: Unknown command line flag 'f'`\n",
        "sys.argv = sys.argv[:1]\n",
        "\n",
        "# `app.run` parses the flags, runs the no-op main and then calls `sys.exit`.\n",
        "# Catch only that SystemExit so unrelated exceptions and KeyboardInterrupt are\n",
        "# not silently swallowed.\n",
        "try:\n",
        "  app.run(lambda argv: None)\n",
        "except SystemExit:\n",
        "  pass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mXzqN7Z-EEuG"
      },
      "outputs": [],
      "source": [
        "# Load experiment HParams\n",
        "\n",
        "# Drop any previously parsed bindings before parsing GIN_CONFIG again, so the\n",
        "# cell can be re-run.\n",
        "gin.clear_config()\n",
        "gin.parse_config(GIN_CONFIG)\n",
        "# Presumably builds the root config from the gin bindings parsed above.\n",
        "params = nerf_config.root_config_from_flags()\n",
        "\n",
        "# Display the resulting configuration.\n",
        "params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mzjL7-wpO1c1"
      },
      "outputs": [],
      "source": [
        "# Resolve the data directory of the first training scene.\n",
        "\n",
        "DATA_DIR_ROOT = params.datasets.data_dir\n",
        "# `train_scenes` is a 'start:end' string (see GIN_CONFIG); keep the start index.\n",
        "SCENE_ID = params.datasets.train_scenes.split(':')[0]\n",
        "# The `/` join requires `data_dir` to be a path-like object rather than a plain str.\n",
        "DATA_DIR = DATA_DIR_ROOT / SCENE_ID"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5eriVjeQQdJN"
      },
      "outputs": [],
      "source": [
        "rng = j3d.RandomState(params.train.random_seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_kAr6wcIr9go"
      },
      "outputs": [],
      "source": [
        "dataset = datasets.get_dataset(\n",
        "    split=\"train\",\n",
        "    args=params.datasets,\n",
        "    model_args=params.models,\n",
        "    example_type=datasets.ExampleType.RAY,\n",
        "    ds_state=None,\n",
        "    is_novel_scenes=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xf-g-Mx5sHBF"
      },
      "outputs": [],
      "source": [
        "_, placeholder_batch = dataset.peek()\n",
        "# Keep only index 0 along the two leading axes of every leaf.\n",
        "# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.\n",
        "placeholder_batch = jax.tree_util.tree_map(lambda t: t[0, 0, ...], placeholder_batch)\n",
        "print(placeholder_batch.target_view.rays.scene_id.shape)\n",
        "print('scene_name:', dataset.all_metadata[0].scene_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PuiX4A51rSqv"
      },
      "outputs": [],
      "source": [
        "recompute_sigma_grid_opts = semantic_utils.RecomputeSigmaGridOptions.from_params(params.train)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VRr7uiRqtnOZ"
      },
      "outputs": [],
      "source": [
        "# Initialize \u0026 load per-scene NeRF models.\n",
        "\n",
        "recovered_nerf_state = semantic_utils.load_all_nerf_variables(\n",
        "    save_dir=params.train.nerf_model_ckpt,\n",
        "    train_dataset=dataset,\n",
        "    novel_dataset=None,\n",
        "    recompute_sigma_grid_opts=recompute_sigma_grid_opts\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6I3AjwlOxm0B"
      },
      "outputs": [],
      "source": [
        "# Select pretrained NeRF corresponding to scene 0.\n",
        "\n",
        "# scene_id corresponding to ray=0.\n",
        "scene_id = placeholder_batch.target_view.rays.scene_id[0, 0]\n",
        "print('scene_id:', scene_id)\n",
        "\n",
        "nerf_variables = semantic_utils.select_and_stack([scene_id],\n",
        "                                                  recovered_nerf_state.train_variables,\n",
        "                                                  num_devices=1)\n",
        "nerf_sigma_grid = semantic_utils.select_and_stack([scene_id],\n",
        "                                                  recovered_nerf_state.train_sigma_grids,\n",
        "                                                  num_devices=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EcUpUSQDx7dK"
      },
      "outputs": [],
      "source": [
        "# Extract NeRF state corresponding to device=0, scene=0.\n",
        "\n",
        "# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.\n",
        "nerf_variables = jax.tree_util.tree_map(lambda x: x[0, 0], nerf_variables)\n",
        "nerf_sigma_grid = jax.tree_util.tree_map(lambda x: x[0, 0], nerf_sigma_grid)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wpETF73Kx_3I"
      },
      "outputs": [],
      "source": [
        "# Drop the first dimension of nerf_sigma_grid if necessary.\n",
        "\n",
        "if len(nerf_sigma_grid.shape) == 5:\n",
        "  assert nerf_sigma_grid.shape[0] == 1\n",
        "  nerf_sigma_grid = nerf_sigma_grid[0]\n",
        "nerf_sigma_grid.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L1BZOGXpoLdG"
      },
      "source": [
        "# Visualize Sigma Field"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wAMFqOGc2a5G"
      },
      "outputs": [],
      "source": [
        "# Plot XYZ values where density \u003e min_density for various values of min_density.\n",
        "\n",
        "def plot_density_coordinates(kept_points, color=None, ax=None):\n",
        "  ii, jj, kk = kept_points[:, 0], kept_points[:, 1], kept_points[:, 2]\n",
        "  if ax is None:\n",
        "    ax = plt.axes(projection='3d')\n",
        "\n",
        "  c = kk\n",
        "  if color is not None:\n",
        "    c = color\n",
        "  ax.scatter(ii, jj, kk, c=c, s=1, cmap='viridis', linewidth=1);\n",
        "  ax.set_xlabel('x')\n",
        "  ax.set_ylabel('y')\n",
        "  ax.set_zlabel('z')\n",
        "  return ax\n",
        "\n",
        "def plot_density_coordinates_min_density(nerf_sigma_grid, eligible_points, min_density_values):\n",
        "  \"\"\"Plots one 3D scatter subplot per density threshold, side by side.\n",
        "\n",
        "  NOTE(review): `binarize_sigma_grid` is not defined or imported anywhere in\n",
        "  this notebook, so calling this function as-is raises NameError. Confirm\n",
        "  whether this helper should be removed or the missing function restored.\n",
        "\n",
        "  Args:\n",
        "    nerf_sigma_grid: Sigma grid forwarded to `binarize_sigma_grid`.\n",
        "    eligible_points: Candidate points forwarded to `binarize_sigma_grid`.\n",
        "    min_density_values: Sequence of thresholds; one subplot is drawn for each.\n",
        "\n",
        "  Returns:\n",
        "    Tuple `(fig, axs)`: the matplotlib figure and the list of its 3D axes.\n",
        "  \"\"\"\n",
        "  # Each subplot gets a 4x4 inch panel.\n",
        "  fig = plt.figure(figsize=(len(min_density_values) * 4, 4))\n",
        "  axs = []\n",
        "  for i, min_density in enumerate(min_density_values):\n",
        "    kept_points_sigma_grid = binarize_sigma_grid(nerf_sigma_grid, eligible_points, min_density)\n",
        "    # Subplot indices are 1-based.\n",
        "    ax = fig.add_subplot(1, len(min_density_values), i+1, projection='3d')\n",
        "    ax = plot_density_coordinates(kept_points_sigma_grid, ax=ax)\n",
        "    ax.set_title(f'min_density = {min_density}')\n",
        "    axs.append(ax)\n",
        "  return fig, axs\n",
        "\n",
        "\n",
        "MIN_DENSITY_VALUES = [0, 2, 4, 8, 16, 32, 64]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mWWvF0NV20vN"
      },
      "outputs": [],
      "source": [
        "# Plot XYZ values where density \u003e 0, interactively.\n",
        "\n",
        "def plot_density_coordinates_interactive(kept_points, color=None):\n",
        "  \"\"\"Builds an interactive plotly 3D scatter plot of the given points.\n",
        "\n",
        "  Args:\n",
        "    kept_points: Non-empty sequence of (x, y, z) records.\n",
        "    color: Optional per-point color values. When None, points are colored by\n",
        "      their negated z coordinate.\n",
        "\n",
        "  Returns:\n",
        "    The plotly figure; it is not shown by this function.\n",
        "  \"\"\"\n",
        "  assert len(kept_points)\n",
        "  df = pd.DataFrame.from_records(kept_points, columns=['x', 'y' ,'z'])\n",
        "  # Default coloring: negated height.\n",
        "  df['color'] = df['z'] * -1\n",
        "  if color is not None:\n",
        "    df['color'] = color\n",
        "  fig = px.scatter_3d(df, x='x', y='y', z='z', color='color', color_continuous_scale=px.colors.sequential.gray)\n",
        "\n",
        "  # Reduce size of each dot.\n",
        "  fig.update_traces(marker={'size': 2})\n",
        "\n",
        "  # Use a fixed 1:1:1 aspect ratio for the three scene axes.\n",
        "  fig.update_layout(scene_aspectmode='manual',\n",
        "                    scene_aspectratio=dict(x=1, y=1, z=1))\n",
        "\n",
        "  # Update axis direction to match matplotlib.\n",
        "  # NOTE(review): `update_yaxes` addresses 2D cartesian axes; the y axis of a 3D\n",
        "  # scene lives under `layout.scene.yaxis`, so this call likely has no effect\n",
        "  # on the figure -- confirm the intended orientation before changing it.\n",
        "  fig.update_yaxes(autorange=\"reversed\")\n",
        "\n",
        "  return fig\n",
        "\n",
        "MIN_DENSITY = 4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nd58CJ27nmn3"
      },
      "outputs": [],
      "source": [
        "# Generate a 3D lattice of query points.\n",
        "\n",
        "n = 64\n",
        "X = np.linspace(-1, 1, num=n)\n",
        "Y = np.linspace(-1, 1, num=n)\n",
        "Z = np.linspace(-1, 1, num=n)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "plbWUX5Ji2DK"
      },
      "outputs": [],
      "source": [
        "nerf_model = recovered_nerf_state.model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AXyuLbyflnmE"
      },
      "outputs": [],
      "source": [
        "xx, yy, zz = np.meshgrid(X, Y, Z, indexing='ij')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rTiywqJSuoIZ"
      },
      "outputs": [],
      "source": [
        "xx.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "luVY2sUtlvKH"
      },
      "outputs": [],
      "source": [
        "# Flatten the lattice into a [1, n**3, 3] batch of xyz query points. Stacking the\n",
        "# raveled coordinate grids yields the same points, in the same order, as zipping\n",
        "# them one by one, without building n**3 Python lists.\n",
        "positions = jnp.asarray(\n",
        "    np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=-1)[np.newaxis])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q2e2TcxRu0pb"
      },
      "outputs": [],
      "source": [
        "positions.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BLMDkD0yjwUY"
      },
      "outputs": [],
      "source": [
        "# Bundle the lattice into a SamplePoints query for scene 0.\n",
        "# NOTE(review): `rng.next()` requires `rng` to be the `j3d.RandomState` created\n",
        "# earlier. The semantic-field section later rebinds `rng` to a raw\n",
        "# `jax.random.PRNGKey`, so re-running this cell after that point would fail.\n",
        "p = types.SamplePoints(\n",
        "    scene_id=jnp.asarray([[0]]),\n",
        "    position=positions,\n",
        "    # A single random direction of shape [1, 3] is supplied for the whole query.\n",
        "    direction=jax.random.uniform(rng.next(), shape=[1, 3]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5wkKlvlui6uz"
      },
      "outputs": [],
      "source": [
        "# Query sigma field across 3D lattice of query points.\n",
        "\n",
        "result = nerf_model.apply(nerf_variables, p)\n",
        "sigma_values = result.sigma\n",
        "color_values = jax.nn.sigmoid(result.rgb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YLwfXoxbu4Na"
      },
      "outputs": [],
      "source": [
        "print(sigma_values.shape, sigma_values.dtype)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "27p3RjhdnHZ0"
      },
      "outputs": [],
      "source": [
        "# Plot sigma field\n",
        "\n",
        "# Reuse the thresholds declared next to the plotting helpers instead of\n",
        "# repeating the literal list here.\n",
        "min_density_values = MIN_DENSITY_VALUES\n",
        "fig = plt.figure(figsize=(len(min_density_values) * 4, 4))\n",
        "axs = []\n",
        "for i, min_density in enumerate(min_density_values):\n",
        "  mask = sigma_values \u003e min_density\n",
        "  # Per-point boolean selector, shared by the positions and their colors.\n",
        "  keep = mask[0, :, 0]\n",
        "  kept_points = positions[0][keep]\n",
        "  kept_points_color = color_values[0][keep]\n",
        "  ax = fig.add_subplot(1, len(min_density_values), i + 1, projection='3d')\n",
        "  ax = plot_density_coordinates(kept_points, color=kept_points_color, ax=ax)\n",
        "  ax.set_title(f'min_density = {min_density}')\n",
        "  axs.append(ax)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yIHfhunSFvDR"
      },
      "source": [
        "# Visualize Semantic Field"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tQJxgOjmaYqL"
      },
      "outputs": [],
      "source": [
        "from jax3d.projects.nesf.nerfstatic.models import volumetric_semantic_model\n",
        "from jax3d.projects.nesf.nerfstatic.utils import types\n",
        "from jax3d.projects.nesf.utils.typing import PRNGKey, Tree, f32  # pylint: disable=g-multiple-import"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Q_KNU8kO68g"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): this rebinds `rng` from the `j3d.RandomState` created earlier to a\n",
        "# raw JAX PRNG key. Cells above that call `rng.next()` cannot be re-run after\n",
        "# this point without first re-creating the RandomState.\n",
        "rng = jax.random.PRNGKey(params.train.random_seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dQoOv5HK53Y1"
      },
      "outputs": [],
      "source": [
        "def plot_3D(points, clusters):\n",
        "  \"\"\"Shows an interactive 3D scatter plot of points colored per point.\n",
        "\n",
        "  Args:\n",
        "    points: Array indexed as [point, coordinate]; columns 0, 1 and 2 are the\n",
        "      x, y and z coordinates.\n",
        "    clusters: Per-point values used to color the markers.\n",
        "  \"\"\"\n",
        "  fig = px.scatter_3d(x=points[:, 0],\n",
        "                      y=points[:, 1],\n",
        "                      z=points[:, 2],\n",
        "                      color=clusters,\n",
        "                      )\n",
        "\n",
        "  fig.update_traces(marker=dict(size=1),\n",
        "                    selector=dict(mode='markers'))\n",
        "\n",
        "  # The axes of a 3D plot live under `layout.scene`. `update_xaxes` and\n",
        "  # `update_yaxes` only address 2D cartesian axes and would leave these ranges\n",
        "  # unset, so the x/y ranges are applied to the scene axes instead.\n",
        "  fig.update_layout(scene=dict(xaxis=dict(range=[-1, 1]),\n",
        "                               yaxis=dict(range=[-1, 1])))\n",
        "  fig.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JKNjgx7-aFsC"
      },
      "outputs": [],
      "source": [
        "def predict_fn_3d(\n",
        "    rng: PRNGKey,\n",
        "    points: types.SamplePoints,\n",
        "    nerf_variables: Tree[jnp.ndarray],\n",
        "    nerf_sigma_grid: f32[\"1 x y z c\"],\n",
        "    *,\n",
        "    semantic_variables: Tree[jnp.ndarray],\n",
        "    semantic_model: volumetric_semantic_model.VolumetricSemanticModel,\n",
        ") -\u003e f32[\"D n k\"]:\n",
        "  \"\"\"Predict semantic logits for a set of 3D points.\n",
        "\n",
        "  Args:\n",
        "    rng: JAX PRNG key. It is split into one key per RNG stream handed to\n",
        "      `semantic_model.apply` (\"params\", \"sampling\", \"data_augmentation\").\n",
        "    points: 3D points to evaluate. Batch size is 'n'.\n",
        "    nerf_variables: NeRF Model's variables\n",
        "    nerf_sigma_grid: NeRF sigma grid.\n",
        "    semantic_variables: Semantic model variables.\n",
        "    semantic_model: Semantic model for rendering.\n",
        "\n",
        "  Returns:\n",
        "    semantic_logits: The second value returned by `semantic_model.apply`.\n",
        "      It is annotated as an array of shape [D, n, k], where\n",
        "        D - number of total devices\n",
        "        n - number of points per device.\n",
        "        k - number of semantic classes.\n",
        "  \"\"\"\n",
        "  rng_names = [\"params\", \"sampling\", \"data_augmentation\"]\n",
        "  # The first split key is discarded; the rest are paired with rng_names.\n",
        "  _, *rng_keys = jax.random.split(rng, len(rng_names) + 1)\n",
        "\n",
        "  # Construct dummy rays to render. The current implementation of\n",
        "  # VolumetricSemanticModel requires a set of rays to be provided.\n",
        "  # One ray per local device, all starting at the origin and pointing along\n",
        "  # the normalized (1, 1, 1) direction.\n",
        "  normalize_fn = lambda x: x / jnp.linalg.norm(x, axis=-1, keepdims=True)\n",
        "  num_dummy_rays = jax.local_device_count()\n",
        "  dummy_rays = types.Rays(\n",
        "      scene_id=jnp.zeros((num_dummy_rays, 1), dtype=jnp.int32),\n",
        "      origin=jnp.zeros((num_dummy_rays, 3)),\n",
        "      direction=normalize_fn(jnp.ones((num_dummy_rays, 3))))\n",
        "\n",
        "  _, predictions = semantic_model.apply(\n",
        "      semantic_variables,\n",
        "      rngs=dict(zip(rng_names, rng_keys)),\n",
        "      rays=dummy_rays,\n",
        "      sigma_grid=nerf_sigma_grid,\n",
        "      randomized_sampling=True,\n",
        "      is_train=False,\n",
        "      nerf_model_weights=nerf_variables,\n",
        "      points=points)\n",
        "\n",
        "  return predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yDV5hUlHGWp1"
      },
      "outputs": [],
      "source": [
        "# Create placeholder batch for model initialization.\n",
        "\n",
        "placeholder_batch = dataset.peek()[1]\n",
        "# Keep only index 0 along the two leading axes of every leaf.\n",
        "# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.\n",
        "placeholder_batch = jax.tree_util.tree_map(lambda t: t[0, 0, ...], placeholder_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UkQk3IcfHhN2"
      },
      "outputs": [],
      "source": [
        "# Load pre-trained NeRF model sigma grids and parameters.\n",
        "# NOTE(review): this repeats the load done in the \"Load Dataset\" section and\n",
        "# rebinds `recovered_nerf_state`. The only visible difference is that the\n",
        "# training dataset is also passed as `novel_dataset` -- confirm that the second\n",
        "# load is required.\n",
        "\n",
        "recovered_nerf_state = semantic_utils.load_all_nerf_variables(\n",
        "    save_dir = params.train.nerf_model_ckpt,\n",
        "    train_dataset = dataset,\n",
        "    novel_dataset = dataset,\n",
        "    recompute_sigma_grid_opts=(\n",
        "        semantic_utils.RecomputeSigmaGridOptions.from_params(params.train)\n",
        "    )\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z8Qx7fsPHnPS"
      },
      "outputs": [],
      "source": [
        "# Initialize semantic model.\n",
        "\n",
        "initialized_vol_sem_model = models.construct_volumetric_semantic_model(\n",
        "    rng=j3d.RandomState(0),\n",
        "    num_scenes=-1,\n",
        "    placeholder_batch=placeholder_batch,\n",
        "    args=params.models,\n",
        "    nerf_model=recovered_nerf_state.model,\n",
        "    nerf_sigma_grid=recovered_nerf_state.train_sigma_grids[0],\n",
        "    nerf_variables=recovered_nerf_state.train_variables[0]\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oWQrlYxcHo4p"
      },
      "outputs": [],
      "source": [
        "vol_sem_model = initialized_vol_sem_model.model\n",
        "semantic_variables = initialized_vol_sem_model.variables\n",
        "\n",
        "# Wrap the freshly initialized variables in an optimizer and TrainState; the\n",
        "# checkpoint-restore cell below takes this `state` as its restore target.\n",
        "# NOTE(review): relies on the `flax.optim` API (flax is pinned to 0.5.3 in the\n",
        "# setup cell).\n",
        "optimizer = flax.optim.Adam(params.train.lr_init).create(semantic_variables)\n",
        "state = utils.TrainState(optimizer=optimizer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EDR0V2HYHrxp"
      },
      "outputs": [],
      "source": [
        "# Restore semantic model from checkpoint.\n",
        "\n",
        "save_dir = train_utils.checkpoint_dir(params)\n",
        "state = train_utils.restore_opt_checkpoint(save_dir=save_dir, state=state)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_177kOY3i45y"
      },
      "outputs": [],
      "source": [
        "save_dir"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kn9W396UabQz"
      },
      "outputs": [],
      "source": [
        "# Query semantic model across 3D lattice of points.\n",
        "\n",
        "# The semantic weights come from the restored checkpoint\n",
        "# (`state.optimizer.target`), not from the freshly initialized\n",
        "# `semantic_variables`.\n",
        "predictions = predict_fn_3d(rng,\n",
        "                            p,\n",
        "                            recovered_nerf_state.train_variables[0],\n",
        "                            recovered_nerf_state.train_sigma_grids[0],\n",
        "                            semantic_variables=state.optimizer.target,\n",
        "                            semantic_model=vol_sem_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XFQB_jBEbBuD"
      },
      "outputs": [],
      "source": [
        "semantic_predictions = jnp.argmax(predictions, axis=-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XnR_3kvcbiKf"
      },
      "outputs": [],
      "source": [
        "# Visualize semantic model predictions across 3D lattice of points.\n",
        "\n",
        "MIN_DENSITY = 16\n",
        "mask = sigma_values \u003e MIN_DENSITY\n",
        "kept_points = positions[0][mask[0, :, 0]]\n",
        "kept_points_color = semantic_predictions[0][mask[0, :, 0]]\n",
        "print(kept_points.shape)\n",
        "plot_3D(kept_points, kept_points_color)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wU31ic_OvCZj"
      },
      "source": [
        "# Visualize Ground Truth"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3CvJLU81FoEP"
      },
      "outputs": [],
      "source": [
        "# Load ground truth train data examples.\n",
        "\n",
        "examples, _ = klevr.make_examples(data_dir=DATA_DIR, split='train', image_idxs=None, enable_sqrt2_buffer=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LaxIMFUiHMb_"
      },
      "outputs": [],
      "source": [
        "@chex.dataclass\n",
        "class LabeledPointCloud:\n",
        "  \"\"\"Point cloud with one row of semantic labels per point.\"\"\"\n",
        "  # Point coordinates; `num_points` requires a rank-2 array.\n",
        "  points: jnp.ndarray\n",
        "  # Semantic labels; rank-2 with the same leading dimension as `points`.\n",
        "  semantics: jnp.ndarray\n",
        "\n",
        "  @property\n",
        "  def num_points(self):\n",
        "    \"\"\"Returns the number of points after validating both array shapes.\"\"\"\n",
        "    points_shape = self.points.shape\n",
        "    assert len(points_shape) == 2, points_shape\n",
        "    semantics_shape = self.semantics.shape\n",
        "    assert len(semantics_shape) == 2, semantics_shape\n",
        "    count = points_shape[0]\n",
        "    assert count == semantics_shape[0]\n",
        "    return count\n",
        "\n",
        "def construct_labeled_point_cloud(batch):\n",
        "  \"\"\"Constructs a semantic-labeled point cloud.\"\"\"\n",
        "  ray_o = batch.target_view.rays.origin\n",
        "  ray_d = batch.target_view.rays.direction\n",
        "  depth = batch.target_view.depth\n",
        "\n",
        "  semantics = batch.target_view.semantics\n",
        "\n",
        "  points = ray_o + depth * ray_d\n",
        "\n",
        "  mask = np.all((points \u003e= -1) \u0026 (points \u003c= 1), axis=-1)\n",
        "\n",
        "  select_points = points[mask]\n",
        "  select_semantics = semantics[mask]\n",
        "\n",
        "  return LabeledPointCloud(points=select_points, semantics=select_semantics)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zm02tmEyJwrU"
      },
      "outputs": [],
      "source": [
        "# Construct labeled semantic point cloud from ground truth dataset (i.e. using semantic masks, ray origins \u0026 directions from known cameras, and depth)\n",
        "\n",
        "labeled_point_cloud = construct_labeled_point_cloud(examples)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KuKH2ghlAEXX"
      },
      "outputs": [],
      "source": [
        "labeled_point_cloud.num_points"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aKQp9JknJsBY"
      },
      "outputs": [],
      "source": [
        "# Randomly subsample the point cloud (with replacement) for plotting. The\n",
        "# generator is seeded so that re-running the notebook draws the same subset.\n",
        "NUM_PLOTTED_POINTS = 200000\n",
        "SUBSAMPLE_SEED = 0\n",
        "subsample_rng = np.random.default_rng(SUBSAMPLE_SEED)\n",
        "idxs = subsample_rng.integers(labeled_point_cloud.num_points, size=NUM_PLOTTED_POINTS)\n",
        "# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.\n",
        "mini_point_cloud = jax.tree_util.tree_map(lambda x: x[idxs], labeled_point_cloud)\n",
        "mini_point_cloud.num_points"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z3xgeLWcA4k7"
      },
      "outputs": [],
      "source": [
        "# Visualize ground truth labeled semantic point cloud.\n",
        "\n",
        "# Reuse the scatter helper from the semantic-field section so predictions and\n",
        "# ground truth are rendered the same way (marker size 1, colored per point by\n",
        "# the first semantics column).\n",
        "plot_3D(mini_point_cloud.points, mini_point_cloud.semantics[:, 0])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1iroBl_8BBq5yoqmVbEWD_-sHV7u3ToFg",
          "timestamp": 1670466957450
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
